{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torchtext\n",
    "from torchtext.vocab import Vectors\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fields: `text` also returns per-example lengths; `label` is a single non-sequential value.\n",
    "text = torchtext.data.Field(include_lengths=True)\n",
    "label = torchtext.data.Field(sequential=False)\n",
    "\n",
    "# Load SST, discarding every example whose label is 'neutral'.\n",
    "train, val, test = torchtext.datasets.SST.splits(\n",
    "    text,\n",
    "    label,\n",
    "    filter_pred=lambda ex: ex.label != 'neutral',\n",
    ")\n",
    "\n",
    "# Both vocabularies are built from the training split only.\n",
    "for field in (text, label):\n",
    "    field.build_vocab(train)\n",
    "\n",
    "train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(\n",
    "    (train, val, test),\n",
    "    batch_size=10,\n",
    "    device=-1,\n",
    "    repeat=False,\n",
    ")\n",
    "\n",
    "# Attach the pretrained wiki.simple vectors to the text vocabulary.\n",
    "url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec'\n",
    "text.vocab.load_vectors(vectors=Vectors('wiki.simple.vec', url=url))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class CNN_CBoW(nn.Module):\n",
    "    \"\"\"Sentence classifier combining max-pooled 1-D convolutions with a CBoW average.\n",
    "\n",
    "    Token embeddings are initialised from ``text.vocab.vectors`` (a notebook-level\n",
    "    global). Two feature sets are concatenated and fed to a single linear layer\n",
    "    followed by log-softmax:\n",
    "\n",
    "    * outputs of convolutions with kernel widths 1-4, ReLU-activated and\n",
    "      max-pooled over every position (``in_channels`` features);\n",
    "    * the mean embedding over the sequence axis plus the sequence length\n",
    "      (``in_channels + 1`` features).\n",
    "\n",
    "    NOTE: ``train`` below shadows ``nn.Module.train(mode)``. ``nn.Module.eval()``\n",
    "    dispatches to ``self.train(False)``, so do not call ``eval()`` on this class;\n",
    "    dropout is switched with the explicit ``train`` flag of ``forward`` instead.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, batch_size):\n",
    "        \"\"\"Create the embedding, convolution, dropout and output layers.\n",
    "\n",
    "        Args:\n",
    "            in_channels: channel count of the convolutions (input and output);\n",
    "                the convolutions are applied to the embeddings, so it has to\n",
    "                equal the embedding dimension.\n",
    "            out_channels: number of output classes.\n",
    "            batch_size: unused; accepted so existing callers keep working.\n",
    "        \"\"\"\n",
    "        super(CNN_CBoW, self).__init__()\n",
    "        self.embeddings = nn.Embedding(text.vocab.vectors.size()[0], text.vocab.vectors.size()[1])\n",
    "        self.embeddings.weight.data.copy_(text.vocab.vectors)\n",
    "        self.convs = nn.ModuleList([nn.Conv1d(in_channels = in_channels, out_channels = in_channels, kernel_size = n) for n in (1,2,3,4)])\n",
    "        self.dropout_train, self.dropout_test = nn.Dropout(p = 0.5), nn.Dropout(p = 0)\n",
    "        # pooled conv features (in_channels) + mean embedding (in_channels) + length (1)\n",
    "        self.linear = nn.Linear(in_features=in_channels*2 + 1, out_features=out_channels, bias = True)\n",
    "\n",
    "    @property\n",
    "    def model(self):\n",
    "        \"\"\"Backward-compatible alias: ``instance.model`` is the instance itself.\n",
    "\n",
    "        Exposed as a read-only property rather than an attribute so the module\n",
    "        is never registered as its own submodule.\n",
    "        \"\"\"\n",
    "        return self\n",
    "\n",
    "    def forward(self, x, train = True):\n",
    "        \"\"\"Return per-class log-probabilities for one batch.\n",
    "\n",
    "        Args:\n",
    "            x: ``(tokens, lengths)`` pair; ``tokens`` is indexed\n",
    "                ``[position, example]`` and ``lengths`` holds one value per example.\n",
    "            train: apply p=0.5 dropout when True, no dropout when False.\n",
    "\n",
    "        Returns:\n",
    "            Log-softmax scores with one row per example.\n",
    "        \"\"\"\n",
    "        x, lengths = x\n",
    "        lengths = Variable(lengths.view(-1, 1).float())\n",
    "        embedded = self.embeddings(x)\n",
    "\n",
    "        # CBoW branch: mean over the sequence axis, with the length appended as an extra feature.\n",
    "        average_embed = embedded.mean(0)\n",
    "        cbow = torch.cat([average_embed, lengths], dim = 1)\n",
    "\n",
    "        # CNN branch: (seq, batch, dim) -> (batch, dim, seq), the layout Conv1d consumes.\n",
    "        embedded = embedded.transpose(1, 2)\n",
    "        embedded = embedded.transpose(0, 2)\n",
    "        # Skip kernels wider than the sequence; outputs are concatenated along the position axis.\n",
    "        concatted_features = torch.cat([conv(embedded) for conv in self.convs if embedded.size(2) >= conv.kernel_size[0]], dim = 2)\n",
    "        activated_features = nn.functional.relu(concatted_features)\n",
    "        pooled = nn.functional.max_pool1d(activated_features, activated_features.size(2)).squeeze(2)\n",
    "\n",
    "        ensemble = torch.cat([pooled, cbow], dim = 1)\n",
    "\n",
    "        dropped = self.dropout_train(ensemble) if train else self.dropout_test(ensemble)\n",
    "        # The classifier must consume `dropped`; feeding `ensemble` would silently disable dropout.\n",
    "        output = self.linear(dropped)\n",
    "        logits = nn.functional.log_softmax(output, dim = 1)\n",
    "        return logits\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the predicted label id per example (argmax shifted by +1), with dropout off.\"\"\"\n",
    "        logits = self.forward(x, train = False)\n",
    "        return logits.max(1)[1] + 1\n",
    "\n",
    "    def train(self, train_iter, val_iter, num_epochs, learning_rate = 1e-3):\n",
    "        \"\"\"Fit the model with Adam and NLL loss, validating after every epoch.\n",
    "\n",
    "        Prints the mean training loss and validation accuracy per epoch, plots\n",
    "        the loss curve, and stores the per-epoch losses in ``self.loss_vec``.\n",
    "\n",
    "        Args:\n",
    "            train_iter: iterable of batches exposing ``.text`` and ``.label``.\n",
    "            val_iter: iterable of batches passed to ``validate``.\n",
    "            num_epochs: number of passes over ``train_iter``.\n",
    "            learning_rate: Adam learning rate.\n",
    "        \"\"\"\n",
    "        criterion = nn.NLLLoss()\n",
    "        # Optimise this instance's parameters, not whatever a global `model` happens to point at.\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate)\n",
    "        loss_vec = []\n",
    "\n",
    "        for epoch in tqdm_notebook(range(1, num_epochs+1)):\n",
    "            epoch_loss = 0\n",
    "            for batch in train_iter:\n",
    "                x = batch.text\n",
    "                y = batch.label\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                y_p = self.forward(x)\n",
    "\n",
    "                # Targets are shifted to 0-based class ids (mirrors the +1 in predict).\n",
    "                loss = criterion(y_p, y-1)\n",
    "                loss.backward()\n",
    "\n",
    "                optimizer.step()\n",
    "                # NOTE(review): `loss.data[0]` is the Variable-era way to read a scalar loss;\n",
    "                # newer PyTorch releases expect `loss.item()` - confirm the installed version.\n",
    "                epoch_loss += loss.data[0]\n",
    "\n",
    "            loss_vec.append(epoch_loss / len(train_iter))\n",
    "            acc = self.validate(val_iter)\n",
    "            print('Epoch {} loss: {} | acc: {}'.format(epoch, loss_vec[-1], acc))\n",
    "\n",
    "        plt.plot(range(len(loss_vec)), loss_vec)\n",
    "        plt.xlabel('Epoch')\n",
    "        plt.ylabel('Loss')\n",
    "        plt.show()\n",
    "        print('\\nModel trained.\\n')\n",
    "        self.loss_vec = loss_vec\n",
    "\n",
    "    def test(self, test_iter):\n",
    "        \"\"\"Predict every batch of ``test_iter`` and write one label per line to predictions.txt.\"\"\"\n",
    "        upload = []\n",
    "        # Update: for kaggle the bucket iterator needs to have batch_size 10\n",
    "        for batch in test_iter:\n",
    "            # batch.label is only used for its length, never for the prediction itself.\n",
    "            x, y = batch.text, batch.label\n",
    "            preds = self.predict(x)[:len(y)]\n",
    "            upload += list(preds.data)\n",
    "\n",
    "        with open('predictions.txt', 'w') as f:\n",
    "            for u in upload:\n",
    "                f.write(str(u) + '\\n')\n",
    "\n",
    "    def validate(self, val_iter):\n",
    "        \"\"\"Return the fraction of examples in ``val_iter`` predicted correctly.\n",
    "\n",
    "        Returns 0.0 when ``val_iter`` yields no examples.\n",
    "        \"\"\"\n",
    "        y_p, y_t = [], []\n",
    "        for batch in val_iter:\n",
    "            x, y = batch.text, batch.label\n",
    "            preds = self.predict(x)[:len(y)]\n",
    "            y_p += list(preds.data)\n",
    "            y_t += list(y.data)\n",
    "        if not y_p:\n",
    "            # Nothing to score; avoid dividing by zero.\n",
    "            return 0.0\n",
    "        correct = sum([1 if i == j else 0 for i, j in zip(y_p, y_t)])\n",
    "        accuracy = correct / len(y_p)\n",
    "        return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3d41eb1a5b44c48a08dfe74a7fa698c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:41: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n",
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:87: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 loss: 0.52801511413014 | acc: 0.8038990825688074\n",
      "Epoch 2 loss: 0.21910578757928242 | acc: 0.7889908256880734\n"
     ]
    }
   ],
   "source": [
    "# Build the classifier, fit it for five epochs, then run it over the test iterator.\n",
    "model = CNN_CBoW(in_channels=300, out_channels=2, batch_size=10)\n",
    "model.train(\n",
    "    train_iter=train_iter,\n",
    "    val_iter=val_iter,\n",
    "    num_epochs=5,\n",
    "    learning_rate=1e-3,\n",
    ")\n",
    "model.test(test_iter)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
